# -*- coding: utf-8 -*-
"""
Created on Thu Sep 30 18:42:27 2021

@author: chunr
"""

import numpy as np
import matplotlib.pyplot as plt
import os 
from scipy.linalg import fractional_matrix_power
# Move to the data directory

# Values to investigate (they are encoded in the result file names below).
data_for_prior_training = 0.5
sigma = 0.1

# Retrieve the error bound and the testing error for every KL penalty (lambda).
kl_penalty = [0.0001, 0.0005, 0.001, 0.005, 0.01, 0.05, 0.1, 0.5, 1]
bound_temp_list = []
testing_error_temp_list = []
# Change directory once, outside the loop: repeating a relative os.chdir on
# every iteration would descend into "your directory/your directory/...".
os.chdir("your directory")
for kk in kl_penalty:
    # File names use the float form of lambda (lambda = 1 -> "1.0_..."), so
    # str(float(kk)) reproduces the names for every entry of kl_penalty.
    file_name = str(float(kk)) + '_' + str(sigma) + '_' + str(data_for_prior_training) + '_cnn_.npy'
    # Load each file once; entry 0 is read as the bound, entry 1 as the test error.
    # allow_pickle=True executes pickled data -- only load trusted result files.
    results = np.load(file_name, allow_pickle=True)
    bound_temp_list.append(results[0])
    testing_error_temp_list.append(results[1])

                
    
# Values to investigate (they are encoded in the NTK result file name below).
sigma = 0.1
data_for_prior_training = 0.5
sample_per_class = 75
NTK_methods = "ntk_init_withdivnothing"

# Load the NTK result file once. Its last two entries are used: [-2] as `rr`
# and [-1] as the NTK matrix. Both are converted with .cpu().numpy()
# (presumably torch tensors -- confirm against the script that saved them).
# os.path.join avoids silently gluing the directory and the method name together.
os.chdir(os.path.join("your directory", NTK_methods))
ntk_results = np.load('0.03_'+str(sample_per_class)+'_'+str(data_for_prior_training)+'_cnn_.npy', allow_pickle=True)
rr = ntk_results[-2].cpu().numpy()
temp_ntk_matrix = ntk_results[-1].cpu().numpy()

# Label-agreement matrix F: for each of the first 10 columns of rr, entry (i, j)
# is +1 when samples i and j have equal values in that column and -1 otherwise;
# the 10 per-column matrices are summed. If rr is a one-hot label matrix
# (assumption -- confirm), this equals V V^T with V = 2*rr - 1.
n_label_columns = 10
final_matrix = sum(
    np.where(rr[:, ww][:, None] == rr[:, ww][None, :], 1, -1)
    for ww in range(n_label_columns)
)

# PA(lambda) = (1/(sigma*n)) * A + (lambda/sigma) * sqrt(A/n), where
# A = trace(F (K + (lambda/sigma) I)^{-2}) and n = sample_per_class.
kl_penalty = [0.0001, 0.0005, 0.001, 0.005, 0.01, 0.05, 0.1, 0.5, 1]
identity = np.identity(len(temp_ntk_matrix))
temp_list = []
for i in kl_penalty:
    ridge = i / sigma
    # (K + (lambda/sigma) I)^{-2}, computed once per lambda.
    inv_sq = fractional_matrix_power(temp_ntk_matrix + identity * ridge, -2)
    # Matrix product, not element-wise: trace(F * M) would only read diag(F),
    # which is constant (10), and so would ignore the labels entirely.
    align = np.trace(final_matrix @ inv_sq)
    r_temp_align_value = np.sqrt(align / sample_per_class) * ridge
    l_temp_align_value = align * (1 / (sigma * sample_per_class))
    temp_list.append(l_temp_align_value + r_temp_align_value)
    
    
from matplotlib.pyplot import figure
import scipy.stats as stats

# One colour and one legend label per lambda, in the same order as kl_penalty.
colors = np.array(["red", "green", "black", "orange", "purple", "lime", "cyan", "magenta", 'navy'])
labels = np.array([r"$\lambda = 1x10^{-4} $", r"$\lambda = 5x10^{-4} $", r"$\lambda = 1x10^{-3} $", r"$\lambda = 5x10^{-3} $", r"$\lambda = 1x10^{-2} $", r"$\lambda = 5x10^{-2} $", r"$\lambda = 1x10^{-1} $", r"$\lambda = 5x10^{-1} $", r"$\lambda = 1 $"])

# Kendall rank correlation between the PA metric and the error bound.
tau, p_value = stats.kendalltau(temp_list, bound_temp_list)
print("Kendall tau:", tau, "p-value:", p_value)

x = temp_list
y = bound_temp_list
# Scatter each point separately so each lambda gets its own colour and legend entry.
for xx, yy, zz, jj in zip(x, y, colors, labels):
    plt.scatter(xx, yy, c=zz, label=jj)
plt.ylabel('Error Bound', fontsize=18)
plt.xlabel(r'$\mathcal{PA}$', fontsize=18)
plt.grid(linestyle='-')
plt.ylim([0.24, 0.57])
# Create the legend once: a second bare plt.legend() call replaces this one
# and throws away loc/borderaxespad.
legend = plt.legend(loc='upper left', borderaxespad=0.)
legend.get_frame().set_edgecolor('black')
# Alternative title showing the PA formula:
#plt.title(r"$ {\frac{1}{\sigma_0}}{Y^T (k(X,X)+ \frac{\lambda}{\sigma_0} I)^{-2} Y} + \frac{\lambda}{\sigma_0} \sqrt{ Y^T(k(X,X)+  \frac{\lambda}{\sigma_0} I)^{-2}Y}$")
plt.title(r"Correlation between $\mathcal{PA}$ and bound"'\n' r"under CNN with different $\lambda$", fontsize=18)
plt.show()
